//===- EraseUnusedOperandsAndResults.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;
using namespace mlir::linalg;

/// Return `true` if the `result` of an operation `genericOp` is dead.
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
  // A result with live uses is never dead.
  if (!result.use_empty())
    return false;

  unsigned resultNumber = result.getResultNumber();

  // If the matching out operand is not read in the payload, the result is
  // trivially dead.
  OpOperand *initOperand = genericOp.getDpsInitOperand(resultNumber);
  if (!genericOp.payloadUsesValueFromOperand(initOperand))
    return true;

  // Otherwise the out operand participates in the payload. It is still
  // considered dead when the corresponding block argument merely cycles
  // through the terminator:
  // - the block argument has exactly one user,
  // - that user is the linalg.yield (and has no uses itself),
  // - the yield forwards the argument back to the same result position, i.e.
  //   the data is not consumed by another out argument.
  BlockArgument bbArg = genericOp.getRegionOutputArgs()[resultNumber];
  if (!bbArg.hasOneUse())
    return false;

  Operation *soleUser = *bbArg.user_begin();
  if (!soleUser->use_empty())
    return false;

  auto yieldOp = dyn_cast<linalg::YieldOp>(soleUser);
  if (!yieldOp)
    return false;

  return yieldOp.getOperand(resultNumber) == bbArg;
}

//===---------------------------------------------------------------------===//
// Helper methods for operand deduplication and dead results elimination
//===---------------------------------------------------------------------===//

// Deduplicate input operands, and return the
// - Mapping from operand position in the original op, to operand position in
// the canonicalized op.
// - The preserved input operands list (by reference).
static llvm::SmallDenseMap<unsigned, unsigned> deduplicateInputOperands(
    GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
    SmallVector<Value> &newInputOperands,
    SmallVector<AffineMap> &newIndexingMaps) {
  llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
  // (value, indexing map) pairs already preserved, mapped to their position
  // in the new input operand list.
  llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> seenInputs;
  for (const auto &[idx, inputOpOperand] :
       llvm::enumerate(genericOp.getDpsInputOperands())) {
    // An operand unused in the payload can be dropped outright, provided the
    // remaining indexing maps still allow computing the loop bounds. Push it
    // speculatively onto the dropped list; pop it back if it cannot go.
    if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
      droppedOpOperands.push_back(inputOpOperand);
      if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
        continue;
      droppedOpOperands.pop_back();
    }

    // Duplicate of an earlier preserved input: reuse that input's position.
    AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
    std::pair<Value, AffineMap> key(inputOpOperand->get(), indexingMap);
    if (auto it = seenInputs.find(key); it != seenInputs.end()) {
      origToNewPos[idx] = it->second;
      droppedOpOperands.push_back(inputOpOperand);
      continue;
    }

    // First occurrence: preserve the operand.
    unsigned newPos = newInputOperands.size();
    origToNewPos[idx] = newPos;
    seenInputs[key] = newPos;
    newInputOperands.push_back(inputOpOperand->get());
    newIndexingMaps.push_back(indexingMap);
  }
  return origToNewPos;
}

// Deduplicate output operands, and return the
// - Mapping from operand position in the original op, to operand position in
// the canonicalized op.
// - The preserved output operands list (by reference).
//
// An output operand is dropped when its result is dead (see
// isResultValueDead) and dropping its indexing map keeps loop-bound
// computation valid, or deduplicated when an earlier preserved output has the
// same init value, indexing map, and yielded value.
llvm::SmallDenseMap<unsigned, unsigned> static deduplicateOutputOperands(
    GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
    SmallVector<Value> &newOutputOperands,
    SmallVector<AffineMap> &newIndexingMaps, bool removeOutputs) {
  llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
  // Map from (init value, indexing map, yielded value) to the new position of
  // the first preserved output computing that value.
  llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
      dedupedOutpts;
  // If the op doesn't have tensor semantics or outputs should not be removed,
  // keep all the outputs as preserved.
  if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
    for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
      origToNewPos[en.index()] = newOutputOperands.size();
      newOutputOperands.push_back(en.value().get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&en.value()));
    }
    return origToNewPos;
  }
  // Output argument can be dropped if the result has
  // - no users, and
  // - it is not used in the payload, and
  // - the corresponding indexing maps are not needed for loop bound
  //   computation.
  auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
  for (const auto &outputOpOperand :
       llvm::enumerate(genericOp.getDpsInitsMutable())) {
    OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
    AffineMap indexingMap =
        genericOp.getMatchingIndexingMap(&outputOpOperand.value());
    auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
                               yieldOp->getOperand(outputOpOperand.index()));
    if (isResultValueDead(genericOp, result)) {
      // Check if the opoperand can be dropped without affecting loop
      // bound computation. Add the operand to the list of dropped op
      // operand for checking. If it cannot be dropped, need to pop the
      // value back. A dead-but-undroppable operand may still be
      // deduplicated below.
      droppedOpOperands.push_back(&outputOpOperand.value());
      if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
        continue;
      }
      droppedOpOperands.pop_back();
    }

    // Deduplication is only safe when the payload does not read this out
    // operand's block argument (the argument of the kept duplicate would
    // otherwise have to substitute for it).
    if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
      // The out operand can also be dropped if it is computed redundantly
      // by another result, the conditions for that are
      // - The same operand is used as the out operand
      // - The same indexing map is used
      // - The same yield value is used.
      auto it = dedupedOutpts.find(key);
      if (it != dedupedOutpts.end()) {
        origToNewPos[outputOpOperand.index()] = it->second;
        droppedOpOperands.push_back(&outputOpOperand.value());
        continue;
      }
    }

    // Preserved output: record its new position and register it as a dedup
    // candidate for later outputs.
    origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
    dedupedOutpts[key] = newOutputOperands.size();
    newOutputOperands.push_back(outputOpOperand.value().get());
    newIndexingMaps.push_back(
        genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
  }
  return origToNewPos;
}

// Populate the body of the canonicalized operation.
//
// Merges the payload block of `genericOp` into the (empty) payload of
// `newOp`, remapping block arguments of preserved operands via the given
// position maps, and rewriting the terminator so dropped outputs are no
// longer yielded.
static void populateOpPayload(
    GenericOp genericOp, GenericOp newOp,
    const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
    const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
    RewriterBase &rewriter) {
  // Merge the body of the original op with the new op.
  Block *newOpBlock = &newOp.getRegion().front();
  assert(newOpBlock->empty() && "expected new op to have an empty payload");
  Block *origOpBlock = &genericOp.getRegion().front();
  // replacements[i] is the value standing in for original block argument i
  // after the merge. Entries for dropped operands stay null; those arguments
  // are expected to have no remaining uses by the time blocks are merged
  // (the yield rewrite below removes the last such uses for outputs).
  SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);

  // Replace all arguments in the original op, with arguments from the
  // canonicalized op.
  auto updateReplacements =
      [&](SmallVector<OpOperand *> &origOperands,
          SmallVector<OpOperand *> &newOperands,
          const llvm::SmallDenseMap<unsigned, unsigned> &map) {
        for (const auto &origOperand : llvm::enumerate(origOperands)) {
          auto it = map.find(origOperand.index());
          if (it == map.end())
            continue;
          OpOperand *newOperand = newOperands[it->second];
          replacements[origOperand.value()->getOperandNumber()] =
              newOpBlock->getArgument(newOperand->getOperandNumber());
        }
      };

  SmallVector<OpOperand *> origInputOperands = genericOp.getDpsInputOperands();
  SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
  updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);

  SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range(
      genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range(
      newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  updateReplacements(origOutputOperands, newOutputOperands,
                     origOutsToNewOutsPos);

  // Drop the unused yield args.
  if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
    OpBuilder::InsertionGuard g(rewriter);
    YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
    rewriter.setInsertionPoint(origYieldOp);

    // Every preserved position is filled because origOutsToNewOutsPos maps
    // onto all new output positions; dropped positions are simply skipped.
    SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
    for (const auto &yieldOpOperands :
         llvm::enumerate(origYieldOp.getValues())) {
      auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
      if (it == origOutsToNewOutsPos.end())
        continue;
      newYieldVals[it->second] = yieldOpOperands.value();
    }
    rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
  }

  rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
}

/// Deduplicate redundant operands of `genericOp` and remove its dead results,
/// rebuilding the op when anything changed. Returns the original op untouched
/// when no operand can be dropped; otherwise creates a new op, migrates the
/// payload, replaces all live uses of the old op, and returns the new op.
/// Outputs are only modified when `removeOutputs` is set.
FailureOr<linalg::GenericOp>
mlir::linalg::deduplicateOperandsAndRemoveDeadResults(
    RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs) {
  // Create a map from argument position in the original op to the argument
  // position in the new op. If the argument is dropped it wont have an entry.
  SmallVector<OpOperand *> droppedOpOperands;

  // Information needed to build the new op.
  SmallVector<Value> newInputOperands, newOutputOperands;
  SmallVector<AffineMap> newIndexingMaps;

  // Gather information about duplicate input operands.
  llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
      deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
                               newIndexingMaps);

  // Gather information about the dropped outputs.
  llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
      deduplicateOutputOperands(genericOp, droppedOpOperands, newOutputOperands,
                                newIndexingMaps, removeOutputs);

  // Check if there is any change to operands.
  if (newInputOperands.size() + newOutputOperands.size() ==
      genericOp->getNumOperands())
    return genericOp;

  // Create the new op with the body being empty.
  Location loc = genericOp.getLoc();
  SmallVector<Type> newResultTypes;
  // Only tensor-typed outputs produce SSA results; buffer outputs do not.
  for (Value v : newOutputOperands)
    if (isa<TensorType>(v.getType()))
      newResultTypes.push_back(v.getType());
  auto newOp = GenericOp::create(
      rewriter, loc, newResultTypes, newInputOperands, newOutputOperands,
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      genericOp.getIteratorTypes(), genericOp.getDocAttr(),
      genericOp.getLibraryCallAttr(),
      // Intentionally empty body builder: the payload is filled in by
      // populateOpPayload below.
      [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
        return;
      });
  // Copy over unknown attributes. They might be load bearing for some flow.
  ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
  for (NamedAttribute kv : genericOp->getAttrs())
    if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
      newOp->setAttr(kv.getName(), kv.getValue());

  // Fix up the payload of the canonicalized operation.
  populateOpPayload(genericOp, newOp, origInsToNewInsPos, origOutsToNewOutsPos,
                    rewriter);

  // Replace all live uses of the op. Dropped results keep a null replacement,
  // which is legal since those results are known to have no uses.
  SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
  for (const auto &result : llvm::enumerate(genericOp.getResults())) {
    auto it = origOutsToNewOutsPos.find(result.index());
    if (it == origOutsToNewOutsPos.end())
      continue;
    replacementsVals[result.index()] = newOp.getResult(it->second);
  }
  rewriter.replaceOp(genericOp, replacementsVals);
  return newOp;
}

namespace {

/// Pattern wrapper around deduplicateOperandsAndRemoveDeadResults. Succeeds
/// only when the transform actually rebuilt the op.
struct DeduplicateAndRemoveDeadOperandsAndResults
    : public OpRewritePattern<GenericOp> {
  DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
                                             bool removeOutputs)
      : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<GenericOp> maybeNewOp = deduplicateOperandsAndRemoveDeadResults(
        rewriter, genericOp, removeOutputs);
    // The transform returns the original op unchanged when nothing was
    // deduplicated or removed; that counts as a match failure here.
    if (succeeded(maybeNewOp) && *maybeNewOp != genericOp)
      return success();
    return rewriter.notifyMatchFailure(
        genericOp, "failed to dedup operands/remove dead results");
  }

private:
  /// If unset, outputs are not modified by this pattern.
  bool removeOutputs;
};

/// Remove unused cycles.
/// We can remove unused cycle within a payload of generic region
/// if these conditions are met:
/// - Result from out operand is dead.
/// - Block arg from out operand has a single use in the %cycle
/// instruction.
/// - Cycle has a single use and it is in yield.
struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {

    // If the op doesn't have tensor semantics, preserve the outputs as is.
    if (!genericOp.hasPureTensorSemantics())
      return failure();

    bool hasRemovedCycles = false;
    // Iterate over output operands and remove any unused cycles.
    for (const auto &outputOpOperand :
         llvm::enumerate(genericOp.getDpsInits())) {

      // Check that result from out operand is dead.
      Value result = genericOp.getResult(outputOpOperand.index());
      if (!result.use_empty())
        continue;

      // Check that outputArg has one use in cycle.
      BlockArgument outputArg =
          genericOp.getRegionOutputArgs()[outputOpOperand.index()];
      if (!outputArg.hasOneUse())
        continue;

      // Only handle single-result cycle ops: `replaceOp` below substitutes a
      // single value for all of the op's results and would hit a result-count
      // mismatch otherwise (and `getResult(0)` would not describe the use).
      Operation *cycleOp = *outputArg.user_begin();
      if (cycleOp->getNumResults() != 1)
        continue;

      // Check cycle has at most one use.
      if (!cycleOp->hasOneUse())
        continue;

      // Check that the cycleUser is a yield.
      Operation *cycleUserOp = *cycleOp->user_begin();
      if (!isa<linalg::YieldOp>(cycleUserOp))
        continue;

      // Check that argIndex matches yieldIndex, else data is being used.
      if (cycleUserOp->getOperand(outputOpOperand.index()) !=
          cycleOp->getResult(0))
        continue;

      // Directly replace the cycle with the blockArg such that
      // Deduplicate pattern can eliminate it along with unused yield.
      rewriter.replaceOp(cycleOp, outputArg);
      // Notify the driver that the enclosing op changed even though the
      // modification happened inside its region.
      rewriter.modifyOpInPlace(genericOp, [] {});
      hasRemovedCycles = true;
    }

    return hasRemovedCycles ? success() : failure();
  }
};

/// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
/// ```
/// linalg.generic ins(%a, %b, %a, %b) outs(%a)
/// ^bb0(%in0, %in1, %in2, %in3, %out1)
/// ```
/// Assuming that all %a and %b have the same index map:
/// * All uses of %in0 and %in2 are replaced with %out1
/// * All uses of %in1 are replaced with %in3
/// This pattern can enable additional canonicalizations: In the above example,
/// %in0, %in1 and %in3 have no uses anymore and their corresponding operands
/// can be folded away. This pattern does not modify uses of output block args.
struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // For each used input bbArg, scan the operand list from the back for
    // another operand (input or output) carrying the same SSA value with the
    // same indexing map; the bbArg of that operand becomes the replacement.
    DenseMap<int, int> replacements;
    int numInputs = genericOp.getNumDpsInputs();
    int numOperands = genericOp->getNumOperands();
    for (int i = 0; i < numInputs; ++i) {
      // bbArgs without uses need no replacement.
      if (genericOp.getBody()->getArgument(i).getUses().empty())
        continue;
      for (int j = numOperands - 1; j > i; --j) {
        if (genericOp->getOperand(i) != genericOp->getOperand(j))
          continue;
        if (genericOp.getIndexingMapsArray()[i] !=
            genericOp.getIndexingMapsArray()[j])
          continue;
        replacements[i] = j;
        break;
      }
    }

    // Nothing to fold.
    if (replacements.empty())
      return failure();

    // Redirect all uses of each duplicated bbArg to its replacement bbArg.
    rewriter.modifyOpInPlace(genericOp, [&]() {
      for (auto [before, after] : replacements) {
        rewriter.replaceAllUsesWith(genericOp.getBody()->getArgument(before),
                                    genericOp.getBody()->getArgument(after));
      }
    });

    return success();
  }
};

} // namespace

void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Full canonicalization: dedup inputs, drop dead outputs, and clean up
  // unused payload cycles.
  patterns.add<DeduplicateAndRemoveDeadOperandsAndResults>(
      context, /*removeOutputs=*/true);
  patterns.add<RemoveUnusedCycleInGenericOp>(context);
}

void mlir::linalg::populateEraseUnnecessaryInputsPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Input-only variant: outputs are left untouched (removeOutputs=false).
  patterns.add<DeduplicateAndRemoveDeadOperandsAndResults>(
      context, /*removeOutputs=*/false);
  patterns.add<FoldDuplicateInputBbArgs>(context);
}
